import numpy as np
import tensorflow as tf
from sklearn.decomposition import SparseCoder
import timeit

# Side length, in pixels, of the square input patches: the response functions
# in this file reshape their samples to (N, PATCH_SIZE, PATCH_SIZE, 1).
PATCH_SIZE = 32
# Number of patches processed per iteration of the batched TensorFlow loops
# (original note: size of random sample from entire input space).
BATCH_SIZE = 5000

def responses(samples, scBasis, sparsity, nonNeg, icaFilters):
	"""Compute sparse-coding and ICA responses for a set of image patches.

	Each patch is convolved with the filter bank stored in 'gabors.npy'
	(stride 3, 'SAME' padding, so a 32x32 patch yields an 11x11 grid). The 72
	output channels are viewed as (3, 12, 2); the size-2 axis holds the two
	members of a filter pair, from which a magnitude and a phase angle are
	derived. Magnitudes are mean-centred per patch, projected with the matrix
	in 'forwardPCA.npy', and then encoded twice: by a lasso sparse coder and
	by rectified linear filters.

	Args:
		samples: array reshapeable to (N, PATCH_SIZE, PATCH_SIZE, 1).
		scBasis: dictionary handed to sklearn's SparseCoder.
		sparsity: forwarded to SparseCoder as transform_alpha.
		nonNeg: forwarded to SparseCoder as positive_code.
		icaFilters: matrix whose transpose is applied to the PCA-projected data.

	Returns:
		Tuple (scResp, icaResp, v1Simple, v1ComplexMean, anglesFile):
		scResp is the sparse code, icaResp the filter outputs with negative
		values set to zero, v1Simple the raw filter responses
		(N, 11, 11, 3, 12, 2), v1ComplexMean the per-patch mean magnitude
		(N, 1) and anglesFile the per-pair phase angles (N, 11, 11, 3, 12, 1).
	"""
	samples = samples.reshape((samples.shape[0], PATCH_SIZE, PATCH_SIZE, 1))

	patchCount = samples.shape[0]

	v1Simple = np.zeros((patchCount, 11, 11, 3, 12, 2))
	# 4356 = 11 * 11 * 3 * 12: one magnitude per filter pair and grid position.
	v1Complex = np.zeros((patchCount, 4356))
	anglesFile = np.zeros((patchCount, 11, 11, 3, 12, 1))

	with tf.device('/gpu:0'):
		# The filter bank does not depend on the batch, so read it from disk
		# and convert it once instead of on every loop iteration.
		filtersTensor = tf.convert_to_tensor(np.load('gabors.npy').astype(np.float32).reshape((12, 12, 1, -1)))

		for bStart in range(0, patchCount, BATCH_SIZE):
			# Slicing clips bEnd automatically for the final, shorter batch.
			bEnd = bStart + BATCH_SIZE

			batch = samples[bStart:bEnd, :, :, :]

			v1Responses = tf.nn.conv2d(batch, filtersTensor, [1, 3, 3, 1], 'SAME')
			v1Responses = tf.reshape(v1Responses, [v1Responses.shape[0], 11, 11, 3, 12, 2])

			# Phase of each pair: atan2(second member, first member).
			pair0, pair1 = tf.split(v1Responses, num_or_size_splits = 2, axis = -1)
			angles = tf.atan2(pair1, pair0)

			# Magnitude of each pair, flattened to one row per patch.
			v1CResponses = tf.sqrt(tf.reduce_sum(tf.square(v1Responses), axis = -1))
			v1CResponses = tf.reshape(v1CResponses, [v1CResponses.shape[0], -1])

			v1Simple[bStart:bEnd] = v1Responses.numpy()
			v1Complex[bStart:bEnd] = v1CResponses.numpy()
			anglesFile[bStart:bEnd] = angles.numpy()

	# Disabled experiment, kept for reference: overwrite the four central grid
	# positions with a value below the global minimum.
#	val = np.min(v1Complex) - 1.0
#	v1Complex = v1Complex.reshape((patchCount, 11, 11, 3, 12))
#	v1Complex[:, 5, 5, :, :] = val
#	v1Complex[:, 6, 5, :, :] = val
#	v1Complex[:, 5, 6, :, :] = val
#	v1Complex[:, 6, 6, :, :] = val
#	v1Complex = v1Complex.reshape((patchCount, -1))

	# Remove each patch's own mean magnitude; the mean is returned so that a
	# reconstruction can add it back.
	v1ComplexMean = np.mean(v1Complex, axis = 1, keepdims = True)
	v1Complex -= v1ComplexMean

	forwardPCA = np.load('forwardPCA.npy')
	pcaTransformed = np.dot(v1Complex, forwardPCA.T)

	sc = SparseCoder(scBasis, transform_algorithm = 'lasso_cd', transform_alpha = sparsity, positive_code = nonNeg)
	scResp = sc.transform(pcaTransformed)

	# Linear filter responses, half-wave rectified.
	icaResp = np.dot(pcaTransformed, icaFilters.T)
	icaResp[icaResp < 0.0] = 0.0

	return scResp, icaResp, v1Simple, v1ComplexMean, anglesFile

def responsesV1(v1, scBasis, sparsity, nonNeg, icaFilters):
	"""Compute sparse-coding and ICA responses from precomputed filter responses.

	Same pipeline as the patch-based variant, but it starts from the filter
	responses themselves, so no convolution (and no filter bank) is needed.

	Args:
		v1: filter responses whose last axis holds the two members of each
			filter pair; expected to be (N, 11, 11, 3, 12, 2) so that the
			results fit the preallocated buffers below.
		scBasis: dictionary handed to sklearn's SparseCoder.
		sparsity: forwarded to SparseCoder as transform_alpha.
		nonNeg: forwarded to SparseCoder as positive_code.
		icaFilters: matrix whose transpose is applied to the PCA-projected data.

	Returns:
		Tuple (scResp, icaResp, v1ComplexMean, anglesFile): the sparse code,
		the filter outputs with negative values set to zero, the per-patch
		mean magnitude (N, 1) and the per-pair phase angles
		(N, 11, 11, 3, 12, 1).
	"""
	patchCount = v1.shape[0]

	# 4356 = 11 * 11 * 3 * 12: one magnitude per filter pair and grid position.
	v1Complex = np.zeros((patchCount, 4356))
	anglesFile = np.zeros((patchCount, 11, 11, 3, 12, 1))

	with tf.device('/gpu:0'):
		for bStart in range(0, patchCount, BATCH_SIZE):
			# Slicing clips bEnd automatically for the final, shorter batch.
			bEnd = bStart + BATCH_SIZE

			v1Responses = v1[bStart:bEnd, :]

			# Phase of each pair: atan2(second member, first member).
			pair0, pair1 = tf.split(v1Responses, num_or_size_splits = 2, axis = -1)
			angles = tf.atan2(pair1, pair0)

			# Magnitude of each pair, flattened to one row per patch.
			v1CResponses = tf.sqrt(tf.reduce_sum(tf.square(v1Responses), axis = -1))
			v1CResponses = tf.reshape(v1CResponses, [v1CResponses.shape[0], -1])

			v1Complex[bStart:bEnd] = v1CResponses.numpy()
			anglesFile[bStart:bEnd] = angles.numpy()

	# Remove each patch's own mean magnitude; the mean is returned so that a
	# reconstruction can add it back.
	v1ComplexMean = np.mean(v1Complex, axis = 1, keepdims = True)
	v1Complex -= v1ComplexMean

	forwardPCA = np.load('forwardPCA.npy')
	pcaTransformed = np.dot(v1Complex, forwardPCA.T)

	sc = SparseCoder(scBasis, transform_algorithm = 'lasso_cd', transform_alpha = sparsity, positive_code = nonNeg)
	scResp = sc.transform(pcaTransformed)

	# Linear filter responses, half-wave rectified.
	icaResp = np.dot(pcaTransformed, icaFilters.T)
	icaResp[icaResp < 0.0] = 0.0

	return scResp, icaResp, v1ComplexMean, anglesFile

def responsesPCAV1C(samples):
	"""Return PCA-projected and raw pair magnitudes for a set of image patches.

	Each patch is convolved with the filter bank stored in 'gabors.npy'
	(stride 3, 'SAME' padding, 11x11 output grid). The magnitude of every
	filter pair is computed, mean-centred per patch and projected with the
	matrix in 'forwardPCA.npy'.

	Args:
		samples: array reshapeable to (N, PATCH_SIZE, PATCH_SIZE, 1).

	Returns:
		Tuple (pcaTransformed, v1CCopy): the PCA projection of the centred
		magnitudes, and the uncentred magnitudes with shape (N, 4356).
	"""
	samples = samples.reshape((samples.shape[0], PATCH_SIZE, PATCH_SIZE, 1))

	patchCount = samples.shape[0]

	# Only the magnitudes are returned, so the raw responses and the phase
	# angles are neither computed nor stored here.
	# 4356 = 11 * 11 * 3 * 12: one magnitude per filter pair and grid position.
	v1Complex = np.zeros((patchCount, 4356))

	with tf.device('/gpu:0'):
		# The filter bank does not depend on the batch, so read it from disk
		# and convert it once instead of on every loop iteration.
		filtersTensor = tf.convert_to_tensor(np.load('gabors.npy').astype(np.float32).reshape((12, 12, 1, -1)))

		for bStart in range(0, patchCount, BATCH_SIZE):
			# Slicing clips bEnd automatically for the final, shorter batch.
			bEnd = bStart + BATCH_SIZE

			batch = samples[bStart:bEnd, :, :, :]

			v1Responses = tf.nn.conv2d(batch, filtersTensor, [1, 3, 3, 1], 'SAME')
			v1Responses = tf.reshape(v1Responses, [v1Responses.shape[0], 11, 11, 3, 12, 2])

			# Magnitude of each pair, flattened to one row per patch.
			v1CResponses = tf.sqrt(tf.reduce_sum(tf.square(v1Responses), axis = -1))
			v1CResponses = tf.reshape(v1CResponses, [v1CResponses.shape[0], -1])

			v1Complex[bStart:bEnd] = v1CResponses.numpy()

	# Disabled experiment, kept for reference: overwrite the four central grid
	# positions with a value below the global minimum.
#	val = np.min(v1Complex) - 1.0
#	v1Complex = v1Complex.reshape((patchCount, 11, 11, 3, 12))
#	v1Complex[:, 5, 5, :, :] = val
#	v1Complex[:, 6, 5, :, :] = val
#	v1Complex[:, 5, 6, :, :] = val
#	v1Complex[:, 6, 6, :, :] = val
#	v1Complex = v1Complex.reshape((patchCount, -1))

	# Centre out of place so the uncentred magnitudes can be returned as they
	# are, without an extra defensive copy.
	v1ComplexMean = np.mean(v1Complex, axis = 1, keepdims = True)
	v1Centred = v1Complex - v1ComplexMean

	forwardPCA = np.load('forwardPCA.npy')
	pcaTransformed = np.dot(v1Centred, forwardPCA.T)

	return pcaTransformed, v1Complex

def responsesCropV1(samples, scBasis, sparsity, nonNeg, icaFilters):
	"""Compute sparse-coding and ICA responses from cropped filter responses.

	Works like the uncropped patch pipeline, except that before magnitudes and
	phases are derived, every filter response whose index along the first
	grid axis is 3 or greater is set to zero. Both the full and the cropped
	responses are returned.

	Args:
		samples: array reshapeable to (N, PATCH_SIZE, PATCH_SIZE, 1).
		scBasis: dictionary handed to sklearn's SparseCoder.
		sparsity: forwarded to SparseCoder as transform_alpha.
		nonNeg: forwarded to SparseCoder as positive_code.
		icaFilters: matrix whose transpose is applied to the PCA-projected data.

	Returns:
		Tuple (scResp, icaResp, v1Simple, v1SimpleCropped, v1ComplexMean,
		anglesFile): the sparse code, the filter outputs with negative values
		set to zero, the full and the cropped filter responses (each
		(N, 11, 11, 3, 12, 2)), the per-patch mean magnitude (N, 1) and the
		per-pair phase angles of the cropped responses (N, 11, 11, 3, 12, 1).
	"""
	samples = samples.reshape((samples.shape[0], PATCH_SIZE, PATCH_SIZE, 1))

	patchCount = samples.shape[0]

	v1Simple = np.zeros((patchCount, 11, 11, 3, 12, 2))
	v1SimpleCropped = np.zeros((patchCount, 11, 11, 3, 12, 2))
	# 4356 = 11 * 11 * 3 * 12: one magnitude per filter pair and grid position.
	v1Complex = np.zeros((patchCount, 4356))
	anglesFile = np.zeros((patchCount, 11, 11, 3, 12, 1))

	with tf.device('/gpu:0'):
		# The filter bank does not depend on the batch, so read it from disk
		# and convert it once instead of on every loop iteration.
		filtersTensor = tf.convert_to_tensor(np.load('gabors.npy').astype(np.float32).reshape((12, 12, 1, -1)))

		for bStart in range(0, patchCount, BATCH_SIZE):
			# Slicing clips bEnd automatically for the final, shorter batch.
			bEnd = bStart + BATCH_SIZE

			batch = samples[bStart:bEnd, :, :, :]

			v1Responses = tf.nn.conv2d(batch, filtersTensor, [1, 3, 3, 1], 'SAME')
			v1Responses = tf.reshape(v1Responses, [v1Responses.shape[0], 11, 11, 3, 12, 2])

			# Store the full responses first; the assignment copies the values,
			# so zeroing v1Np afterwards does not affect v1Simple.
			v1Np = v1Responses.numpy()
			v1Simple[bStart:bEnd] = v1Np

			# Crop: keep grid rows 0-2 and zero out everything from row 3 on.
			v1Np[:, 3:, :, :, :, :] = 0.0
			v1Cropped = tf.constant(v1Np)

			# Phase of each pair: atan2(second member, first member).
			pair0, pair1 = tf.split(v1Cropped, num_or_size_splits = 2, axis = -1)
			angles = tf.atan2(pair1, pair0)

			# Magnitude of each pair, flattened to one row per patch.
			v1CResponses = tf.sqrt(tf.reduce_sum(tf.square(v1Cropped), axis = -1))
			v1CResponses = tf.reshape(v1CResponses, [v1CResponses.shape[0], -1])

			v1SimpleCropped[bStart:bEnd] = v1Np
			v1Complex[bStart:bEnd] = v1CResponses.numpy()
			anglesFile[bStart:bEnd] = angles.numpy()

	# Remove each patch's own mean magnitude; the mean is returned so that a
	# reconstruction can add it back.
	v1ComplexMean = np.mean(v1Complex, axis = 1, keepdims = True)
	v1Complex -= v1ComplexMean

	forwardPCA = np.load('forwardPCA.npy')
	pcaTransformed = np.dot(v1Complex, forwardPCA.T)

	sc = SparseCoder(scBasis, transform_algorithm = 'lasso_cd', transform_alpha = sparsity, positive_code = nonNeg)
	scResp = sc.transform(pcaTransformed)

	# Linear filter responses, half-wave rectified.
	icaResp = np.dot(pcaTransformed, icaFilters.T)
	icaResp[icaResp < 0.0] = 0.0

	return scResp, icaResp, v1Simple, v1SimpleCropped, v1ComplexMean, anglesFile

def reconstruct(basis, codes, v1cMean, angles):
	"""Reconstruct image patches from codes over a basis.

	The codes are mapped through the basis and the matrix stored in
	'inversePCA.npy', the per-patch mean is added back, the resulting
	magnitudes are combined with the phase angles into pair responses, and
	those are projected to pixels by a transposed convolution with the filter
	bank in 'gabors.npy'.

	Args:
		basis: matrix multiplied on the right of codes.
		codes: one row of coefficients per patch.
		v1cMean: per-patch mean magnitudes; only the first N rows are used,
			where N is the number of code rows.
		angles: phase angles, broadcast against (N, 11, 11, 3, 12, 1).

	Returns:
		NumPy array of reconstructed patches with shape (N, 32, 32).
	"""
	filters = np.load('gabors.npy').astype(np.float32).reshape((12, 12, 1, -1))

	invPCA = np.load('inversePCA.npy')

	# Named v1cRecon rather than `super` so the builtin is not shadowed.
	v1cRecon = np.dot(codes, basis)
	v1cRecon = np.dot(v1cRecon, invPCA)

	# Undo the per-patch mean removal applied when the codes were computed.
	v1cRecon += v1cMean[:v1cRecon.shape[0]]

	v1cRecon = v1cRecon.reshape((v1cRecon.shape[0], 11, 11, 3, 12, 1))

	# Magnitude and phase back to the two members of each filter pair.
	quadPair0 = v1cRecon * np.cos(angles)
	quadPair1 = v1cRecon * np.sin(angles)

	v1 = np.concatenate((quadPair0, quadPair1), axis = -1)

	with tf.device('/gpu:0'):
		v1Tensor = tf.convert_to_tensor(v1, dtype = tf.float32)
		v1Tensor = tf.reshape(v1Tensor, [v1Tensor.shape[0], 11, 11, -1])

		# filters is already (12, 12, 1, -1); no further reshape is required.
		gaborsTensor = tf.convert_to_tensor(filters, dtype = tf.float32)

		reconst = tf.nn.conv2d_transpose(v1Tensor, gaborsTensor, [v1Tensor.shape[0], 32, 32, 1], [1, 3, 3, 1], padding = 'SAME')

		# Drop the single output channel.
		return reconst.numpy()[:, :, :, 0]

def reconstructV1C(v1C, v1cMean, angles):
	"""Reconstruct image patches from mean-centred pair magnitudes.

	The mean is added back, the magnitudes are combined with the phase angles
	into pair responses, and those are projected to pixels by a transposed
	convolution with the filter bank in 'gabors.npy'.

	Args:
		v1C: centred magnitudes, reshapeable to (N, 11, 11, 3, 12, 1). The
			caller's array is left unmodified.
		v1cMean: mean magnitudes, broadcast against v1C.
		angles: phase angles, broadcast against (N, 11, 11, 3, 12, 1).

	Returns:
		NumPy array of reconstructed patches with shape (N, 32, 32).
	"""
	filters = np.load('gabors.npy').astype(np.float32).reshape((12, 12, 1, -1))

	# Add out of place: an in-place `+=` would silently change the caller's
	# array, so a second call with the same input would add the mean twice.
	v1C = v1C + v1cMean

	v1C = v1C.reshape((v1C.shape[0], 11, 11, 3, 12, 1))

	# Magnitude and phase back to the two members of each filter pair.
	quadPair0 = v1C * np.cos(angles)
	quadPair1 = v1C * np.sin(angles)

	v1 = np.concatenate((quadPair0, quadPair1), axis = -1)

	with tf.device('/gpu:0'):
		v1Tensor = tf.convert_to_tensor(v1, dtype = tf.float32)
		v1Tensor = tf.reshape(v1Tensor, [v1Tensor.shape[0], 11, 11, -1])

		# filters is already (12, 12, 1, -1); no further reshape is required.
		gaborsTensor = tf.convert_to_tensor(filters, dtype = tf.float32)

		reconst = tf.nn.conv2d_transpose(v1Tensor, gaborsTensor, [v1Tensor.shape[0], 32, 32, 1], [1, 3, 3, 1], padding = 'SAME')

		# Drop the single output channel.
		return reconst.numpy()[:, :, :, 0]

def reconstructPCA(pcaTransformed, v1cMean, angles):
	"""Reconstruct image patches from PCA-space magnitudes.

	Args:
		pcaTransformed: PCA coefficients, one row per patch.
		v1cMean: per-patch mean magnitudes; only the first N rows are used,
			where N is the number of coefficient rows.
		angles: phase angles, broadcast against (N, 11, 11, 3, 12, 1).

	Returns:
		NumPy array of reconstructed patches with shape (N, 32, 32).
	"""
	gaborBank = np.load('gabors.npy').astype(np.float32).reshape((12, 12, 1, -1))
	inverseProjection = np.load('inversePCA.npy')

	# Back from PCA space to magnitudes, then restore the per-patch mean.
	magnitudes = np.dot(pcaTransformed, inverseProjection)
	rowCount = magnitudes.shape[0]
	magnitudes += v1cMean[:rowCount]
	magnitudes = magnitudes.reshape((rowCount, 11, 11, 3, 12, 1))

	# Magnitude and phase back to the two members of each filter pair.
	v1 = np.concatenate((magnitudes * np.cos(angles), magnitudes * np.sin(angles)), axis = -1)

	with tf.device('/gpu:0'):
		pairResponses = tf.convert_to_tensor(v1, dtype = tf.float32)
		batchSize = pairResponses.shape[0]
		pairResponses = tf.reshape(pairResponses, [batchSize, 11, 11, -1])

		kernel = tf.reshape(tf.convert_to_tensor(gaborBank, dtype = tf.float32), [12, 12, 1, -1])

		outputShape = [batchSize, 32, 32, 1]
		pixels = tf.nn.conv2d_transpose(pairResponses, kernel, outputShape, [1, 3, 3, 1], padding = 'SAME').numpy()

	# Drop the single output channel.
	return pixels[:, :, :, 0]

def reconstructV1(v1):
	"""Project filter responses back to image patches.

	Args:
		v1: filter responses, reshapeable to (N, 11, 11, C) where C matches
			the number of filters stored in 'gabors.npy'.

	Returns:
		NumPy array of reconstructed patches with shape (N, 32, 32).
	"""
	gaborBank = np.load('gabors.npy').astype(np.float32).reshape((12, 12, 1, -1))

	with tf.device('/gpu:0'):
		pairResponses = tf.convert_to_tensor(v1, dtype = tf.float32)
		batchSize = pairResponses.shape[0]
		# Merge all trailing axes into a single channel axis.
		pairResponses = tf.reshape(pairResponses, [batchSize, 11, 11, -1])

		kernel = tf.convert_to_tensor(gaborBank, dtype = tf.float32)

		outputShape = [batchSize, 32, 32, 1]
		pixels = tf.nn.conv2d_transpose(pairResponses, kernel, outputShape, [1, 3, 3, 1], padding = 'SAME').numpy()

	# Drop the single output channel.
	return pixels[:, :, :, 0]

def reconstructForV1(basis, codes, v1cMean, angles):
	"""Rebuild filter-pair responses from codes, stopping before pixel space.

	The codes are mapped through the basis and the matrix stored in
	'inversePCA.npy', the per-patch mean is added back, and the resulting
	magnitudes are combined with the phase angles into the two members of
	each filter pair. No filter bank is needed because nothing is projected
	to pixels.

	Args:
		basis: matrix multiplied on the right of codes.
		codes: one row of coefficients per patch.
		v1cMean: per-patch mean magnitudes; only the first N rows are used,
			where N is the number of code rows.
		angles: phase angles, broadcast against (N, 11, 11, 3, 12, 1).

	Returns:
		NumPy array of shape (N, 11, 11, 3, 12, 2); index 0 of the last axis
		is magnitude * cos(angle), index 1 is magnitude * sin(angle).
	"""
	invPCA = np.load('inversePCA.npy')

	# Named v1cRecon rather than `super` so the builtin is not shadowed.
	v1cRecon = np.dot(np.dot(codes, basis), invPCA)

	# Undo the per-patch mean removal applied when the codes were computed.
	v1cRecon += v1cMean[:v1cRecon.shape[0]]

	v1cRecon = v1cRecon.reshape((v1cRecon.shape[0], 11, 11, 3, 12, 1))

	quadPair0 = v1cRecon * np.cos(angles)
	quadPair1 = v1cRecon * np.sin(angles)

	return np.concatenate((quadPair0, quadPair1), axis = -1)
